from typing import Optional, Dict
import numpy as np
import torch
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, f1_score
import pandas as pd
from sklearn import metrics
from utils.logger import Logger
# Shared logger object exposed by the project's Logger class; used further
# down in this module to report averaged evaluation metrics.
logger = Logger.logger


def compute_precisions(
    predictions: torch.Tensor,
    targets: torch.Tensor,
    src_lengths: Optional[torch.Tensor] = None,
    minsep: int = 6,
    maxsep: Optional[int] = None,
    override_length: Optional[int] = None,
):
    """Compute ranked contact precisions for a batch of pairwise score maps.

    A pair ``(i, j)`` is eligible when ``j - i >= minsep`` (and
    ``j - i < maxsep`` if ``maxsep`` is given), its target is non-negative
    and, when ``src_lengths`` is given, both positions lie inside the
    example's length. Upper-triangle pairs are ranked by score (ineligible
    ones are pushed to the bottom with ``-inf``) and the precision is read at
    the top 10%, 20%, ..., 100% of a reference length ``L``.

    Args:
        predictions: Scores of shape ``(batch, seqlen, seqlen)`` or
            ``(seqlen, seqlen)``; a tensor or a numpy array.
        targets: Labels with the same shape as ``predictions``; negative
            entries mark pairs that must be ignored.
        src_lengths: Optional per-example lengths of shape ``(batch,)`` used
            to mask out positions beyond each example's length.
        minsep: Minimum ``j - i`` separation (inclusive).
        maxsep: Maximum ``j - i`` separation (exclusive); ``None`` disables
            the upper limit.
        override_length: Reference length ``L`` for the cut-offs. When
            ``None`` it is derived as the number of non-negative entries in
            the first row of the first target map.

    Returns:
        Dict with keys ``"AUC"``, ``"P@L"``, ``"P@L2"`` and ``"P@L5"``; each
        value is a tensor holding one entry per batch element. ``"AUC"`` is
        the mean over the ten cut-offs.

    Raises:
        ValueError: If ``predictions`` and ``targets`` differ in size.
    """
    if isinstance(predictions, np.ndarray):
        predictions = torch.from_numpy(predictions)
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    if predictions.dim() == 2:
        predictions = predictions.unsqueeze(0)
    if targets.dim() == 2:
        targets = targets.unsqueeze(0)

    # Honour an explicit override; only fall back to the target-derived
    # length when the caller did not provide one.
    if override_length is None:
        override_length = (targets[0, 0] >= 0).sum()
    override_length = int(override_length)

    # Check sizes
    if predictions.size() != targets.size():
        raise ValueError(
            f"Size mismatch. Received predictions of size {predictions.size()}, "
            f"targets of size {targets.size()}"
        )
    device = predictions.device

    batch_size, seqlen, _ = predictions.size()
    seqlen_range = torch.arange(seqlen, device=device)

    # sep[0, i, j] == j - i, so only the upper triangle can satisfy minsep > 0.
    sep = seqlen_range.unsqueeze(0) - seqlen_range.unsqueeze(1)
    sep = sep.unsqueeze(0)
    valid_mask = sep >= minsep
    valid_mask = valid_mask & (targets >= 0)  # negative targets are invalid

    if maxsep is not None:
        valid_mask &= sep < maxsep

    if src_lengths is not None:
        valid = seqlen_range.unsqueeze(0) < src_lengths.unsqueeze(1)
        valid_mask &= valid.unsqueeze(1) & valid.unsqueeze(2)

    # Ineligible pairs sink to the bottom of the ranking.
    predictions = predictions.masked_fill(~valid_mask, float("-inf"))

    x_ind, y_ind = np.triu_indices(seqlen, minsep)
    predictions_upper = predictions[:, x_ind, y_ind]
    targets_upper = targets[:, x_ind, y_ind]

    topk = max(seqlen, override_length)
    indices = predictions_upper.argsort(dim=-1, descending=True)[:, :topk]
    row_index = torch.arange(batch_size, device=indices.device).unsqueeze(1)
    topk_targets = targets_upper[row_index, indices]
    if topk_targets.size(1) < topk:
        # Fewer candidate pairs than requested: pad with zeros (non-contacts).
        topk_targets = F.pad(topk_targets, [0, topk - topk_targets.size(1)])

    cumulative_dist = topk_targets.type_as(predictions).cumsum(-1)

    gather_lengths = torch.full(
        (batch_size, 1), override_length, dtype=torch.long, device=device
    )

    # Zero-based rank of the last prediction inside each 10% bin of L.
    gather_indices = (
        torch.arange(0.1, 1.1, 0.1, device=device).unsqueeze(0) * gather_lengths
    ).type(torch.long) - 1

    binned_cumulative_dist = cumulative_dist.gather(1, gather_indices)
    binned_precisions = binned_cumulative_dist / (gather_indices + 1).type_as(
        binned_cumulative_dist
    )

    pl5 = binned_precisions[:, 1]  # top 20% of L
    pl2 = binned_precisions[:, 4]  # top 50% of L
    pl = binned_precisions[:, 9]  # top 100% of L
    auc = binned_precisions.mean(-1)

    return {"AUC": auc, "P@L": pl, "P@L2": pl2, "P@L5": pl5}


# compute contact precisions
def evaluate_prediction(
    predictions: torch.Tensor,
    targets: torch.Tensor,
) -> Dict[str, float]:
    """Compute contact precisions for each sequence-separation range.

    Runs ``compute_precisions`` once per range -- local ``[3, 6)``, short
    ``[6, 12)``, medium ``[12, 24)`` and long ``[24, inf)`` -- and flattens
    the results into a single dict.

    Args:
        predictions: Contact scores as a tensor or numpy array.
        targets: Contact labels as a tensor or numpy array; moved to the
            device of ``predictions`` before evaluation.

    Returns:
        Dict mapping ``"<range>_<metric>"`` (e.g. ``"long_P@L"``) to a Python
        float. Every metric is converted with ``.item()``, so each tensor
        returned by ``compute_precisions`` must hold exactly one value.
    """
    # Accept numpy input for both arguments, as compute_precisions does.
    if isinstance(predictions, np.ndarray):
        predictions = torch.from_numpy(predictions)
    if isinstance(targets, np.ndarray):
        targets = torch.from_numpy(targets)
    contact_ranges = [
        ("local", 3, 6),
        ("short", 6, 12),
        ("medium", 12, 24),
        ("long", 24, None),
    ]
    targets = targets.to(predictions.device)
    # Named so it does not shadow the module-level ``sklearn.metrics`` import.
    range_results: Dict[str, float] = {}
    for name, minsep, maxsep in contact_ranges:
        rangemetrics = compute_precisions(
            predictions,
            targets,
            minsep=minsep,
            maxsep=maxsep,
        )
        for key, val in rangemetrics.items():
            range_results[f"{name}_{key}"] = val.item()
    return range_results

"""
Adapted from https://github.com/BioinfoMachineLearning/CDPred/blob/main/lib/distmap_evaluate.py
"""
def ceil_topxL_to_one(Y_hat, x):
    """Binarise a 1-d score vector by flagging its ``int(x)`` highest scores.

    Args:
        Y_hat: 1-d array-like of scores.
        x: Number of top-scoring entries to flag; truncated with ``int``.
            Values that truncate to zero or less flag nothing, values larger
            than ``len(Y_hat)`` flag everything.

    Returns:
        Integer numpy array of the same length as ``Y_hat`` with 1 at the
        selected positions and 0 elsewhere.
    """
    scores = np.asarray(Y_hat)
    top_k = int(x)
    Y_ceiled = np.zeros(len(scores), dtype=int)
    # ``seq[-0:]`` is the whole sequence, so a non-positive count has to be
    # skipped explicitly instead of being used as a slice bound.
    if top_k > 0:
        Y_ceiled[np.argsort(scores)[-top_k:]] = 1
    return Y_ceiled


def EvaluateComplex(true_complex_vec, pred_complex_vec, tar_length):
    """
    Precision of the highest-scoring predictions at six cut-offs.

    Args:
        true_complex_vec: 1d array-like of ground-truth labels
        pred_complex_vec: 1d array-like of predicted scores
        tar_length: reference length L used for the L-relative cut-offs
    Returns:
        tuple of precisions for top-5, top-10, top-L/10, top-L/5, top-L/2
        and top-L (in that order)
    """
    # Fractional cut-offs are truncated to int inside ceil_topxL_to_one.
    cutoffs = (
        5,
        10,
        tar_length / 10,
        tar_length / 5,
        tar_length / 2,
        tar_length,
    )
    return tuple(
        metrics.precision_score(
            true_complex_vec, ceil_topxL_to_one(pred_complex_vec, cutoff)
        )
        for cutoff in cutoffs
    )


def evaluate_inter_chain_prediction(inter_chain_pred, inter_chain_targ, short_len=None):
    """
    Inter-chain contact precision for one example.

    Args:
        inter_chain_pred: np.array of shape (len1, len2), predicted scores
        inter_chain_targ: np.array of shape (len1, len2), ground-truth labels
        short_len: reference length for the L-relative cut-offs; falls back
            to min(len1, len2) when None
    Returns:
        dict with keys 'top5', 'top10', 'topL10', 'topL5', 'topL2', 'topL'
    """
    n_rows, n_cols = inter_chain_pred.shape
    if short_len is None:
        short_len = min(n_rows, n_cols)

    # Both matrices are flattened so each residue pair is one sample.
    flat_pred = inter_chain_pred.reshape(-1)
    flat_targ = inter_chain_targ.reshape(-1)

    top5, top10, topL10, topL5, topL2, topL = EvaluateComplex(flat_targ, flat_pred, short_len)
    return {
        'top5': top5,
        'top10': top10,
        'topL10': topL10,
        'topL5': topL5,
        'topL2': topL2,
        'topL': topL,
    }


def distogram_probs_to_contact_probs(disto_probs, upper_bound=19):
    """ Collapse distogram bin probabilities into a contact probability.

    The probabilities of bins 0..upper_bound (inclusive) along the third
    axis are selected and summed over the last axis.
    Contact is defined as cbcb_dis < 8.25 --> dis_bins = [0,...,19]
    for 64 bins: upper_bound = 19
    for 18 bins: upper_bound = 5
    """
    n_contact_bins = upper_bound + 1
    contact_slice = disto_probs[:, :, :n_contact_bins]
    return torch.sum(contact_slice, dim=-1)


def compute_inter_chain_contact_precision_from_distogram(disto_logits, disto_labels, lengths, upper_bound=19):
    """
    Mean top-L/5 inter-chain contact precision over all examples.

    Args:
        disto_logits: dict, {name: tensor (seqlen, seqlen, num_bins)}
        disto_labels: dict, {name: np.array of shape (seqlen, seqlen)}, labels are symmetric matrix, bin_index starts from 1
        lengths: dict, {name: [len1, len2]}; only len1 is used, to split the
            map into its inter-chain block [:len1, len1:]
        upper_bound: highest (0-based) distogram bin that counts as contact
    Returns:
        np.mean of the per-example 'topL5' precisions
    """
    topL5_precisions = []
    for name, logits in disto_logits.items():
        disto_probs = torch.nn.functional.softmax(logits, dim=-1)  # (seqlen, seqlen, num_bins)
        contact_probs = distogram_probs_to_contact_probs(disto_probs, upper_bound)
        # Labels are 1-based bin indices, hence the shift before thresholding.
        contact_labels = (disto_labels[name] - 1) <= upper_bound
        # inter-chain contact prediction
        len1 = lengths[name][0]
        # detach/cpu so tensors on an accelerator or tracking gradients can
        # still be converted to numpy.
        inter_chain_probs = contact_probs[:len1, len1:].detach().cpu().numpy()
        inter_chain_labels = contact_labels[:len1, len1:]
        example_metrics = evaluate_inter_chain_prediction(inter_chain_probs, inter_chain_labels)
        topL5_precisions.append(example_metrics['topL5'])
    return np.mean(topL5_precisions)


def evaulate_contact_predicsion_from_distogram(disto_logits, disto_labels, lengths, disto_loss, upper_bound=19):
    """
    Evaluate distogram predictions and log the metrics averaged over examples.

    Three groups of per-example metrics are collected and their column means
    are logged: distogram classification (accuracy, macro/micro F1, logloss),
    whole-sequence contact precision, and inter-chain contact precision.

    Args:
        disto_logits: dict, {name: tensor (seqlen, seqlen, num_bins)}
        disto_loss: dict, {name: scalar}, or None (logloss is then reported as 0)
        disto_labels: dict, {name: np.array of shape (seqlen, seqlen)}, labels are symmetric matrix, bin_index starts from 1
        lengths: dict, {name: [len1, len2]}
        upper_bound: highest (0-based) distogram bin that counts as contact
    """
    results = []
    whole_seq_contact_results = []
    inter_chain_contact_results = []
    logloss = 0
    for name, logits in disto_logits.items():
        disto_probs = torch.nn.functional.softmax(logits, dim=-1)  # (seqlen, seqlen, num_bins)
        # For each residue pair (i, j), take the bin with maximum probability as the predicted label.
        # detach/cpu keeps the numpy conversion valid for accelerator tensors.
        pred_labels = torch.argmax(disto_probs, dim=-1).detach().cpu().numpy()  # (seqlen, seqlen)
        # only consider upper triangle (diagonal included)
        n = pred_labels.shape[0]
        upper_idx = np.triu_indices(n)
        pred_labels = pred_labels[upper_idx]            # 1-d array
        labels = (disto_labels[name] - 1)[upper_idx]    # 1-d array, shifted to 0-based bins
        accuracy = accuracy_score(labels, pred_labels)
        macro_f1 = f1_score(labels, pred_labels, average='macro')
        micro_f1 = f1_score(labels, pred_labels, average='micro')
        if disto_loss is not None:
            # Store a plain float so the column is numeric when averaged.
            logloss = float(disto_loss[name])
        results.append({'id': name, 'acc': accuracy, 'macro_f1': macro_f1, 'micro_f1': micro_f1, 'logloss': logloss})

        # whole sequence contact prediction
        contact_probs = distogram_probs_to_contact_probs(disto_probs, upper_bound)
        # NOTE(review): a label of 0 (if it marks a missing pair) becomes -1
        # after the shift and is counted as a contact -- confirm with the data pipeline.
        contact_labels = (disto_labels[name] - 1) <= upper_bound
        # Local names avoid shadowing the module-level ``sklearn.metrics`` import.
        whole_seq_metrics = {"id": name}
        whole_seq_metrics.update(evaluate_prediction(contact_probs, contact_labels))
        whole_seq_contact_results.append(whole_seq_metrics)

        # inter-chain contact prediction on the [:len1, len1:] block
        inter_chain_metrics = {"id": name}
        len1, _ = lengths[name]
        inter_chain_probs = contact_probs[:len1, len1:].detach().cpu().numpy()
        inter_chain_labels = contact_labels[:len1, len1:]
        inter_chain_metrics.update(evaluate_inter_chain_prediction(inter_chain_probs, inter_chain_labels))
        inter_chain_contact_results.append(inter_chain_metrics)

    results = pd.DataFrame(results)
    inter_chain_contact_results = pd.DataFrame(inter_chain_contact_results)
    whole_seq_contact_results = pd.DataFrame(whole_seq_contact_results)

    # numeric_only=True skips the string 'id' column, which would otherwise
    # make DataFrame.mean raise on recent pandas versions.
    # distogram metrics
    logger.info('\n{}'.format(results.mean(numeric_only=True)))
    # inter-chain contact metrics
    logger.info('\n{}'.format(inter_chain_contact_results.mean(numeric_only=True)))
    # whole sequence contact metrics
    logger.info('\n{}'.format(whole_seq_contact_results.mean(numeric_only=True)))

    